use super::{current_worker, Runtime};
use crate::Error;
use core::alloc::Layout;
use core::cell::UnsafeCell;
use core::mem::{self, ManuallyDrop};
use core::ptr;
use core::ptr::NonNull;
use core::slice;
use core::sync::atomic::{
    AtomicPtr,
    Ordering::{Acquire, Relaxed, Release},
};
use hicollections::{List, ListNode};
use hipool::{Allocator, Boxed, MemPool, NullAlloc, PoolAlloc};

const PAGE_SIZE: usize = 4096;
const PAGE_MASK: usize = PAGE_SIZE - 1;
const MAX_SIZE: usize = PAGE_SIZE - mem::size_of::<MemPage>();
const HALF_SIZE: usize = MAX_SIZE / 2;
const HALF_SIZE_1: usize = HALF_SIZE + 1;

// Allocation interface for the worker that owns this cache. Small requests
// (<= MAX_SIZE) are served from the worker-local size-class caches; larger
// ones fall through to the global `PoolAlloc`.
unsafe impl<'a> Allocator for &'a MemCache {
    /// Allocate `layout.size()` bytes from the matching size class, hand the
    /// buffer to `f`, then return it.
    ///
    /// Must only be called on the worker that owns this cache — the class
    /// caches are single-threaded and reached through an `UnsafeCell`
    /// (checked via `debug_assert!` below).
    ///
    /// NOTE(review): if `f` returns `Err`, the freshly popped block is neither
    /// freed nor returned to the cache (it leaks) — confirm callers expect
    /// this, or that `PoolAlloc::alloc_buf` has the same contract.
    unsafe fn alloc_buf<F>(&self, layout: Layout, f: F) -> Result<NonNull<[u8]>, Error>
    where
        F: FnOnce(NonNull<[u8]>) -> Result<(), Error>,
    {
        debug_assert!(current_worker().unwrap().group_id() == self.group);
        debug_assert!(current_worker().unwrap().worker_id() == self.worker);
        // SAFETY: only the owning worker reaches `inner`, so this exclusive
        // borrow cannot alias another live borrow.
        let inner = unsafe { &mut *self.inner.get() };
        if layout.size() <= MAX_SIZE {
            // Round the request up to its size class.
            let buf = match layout.size() {
                0..=16 => inner.cache16.alloc(self.group, self.worker),
                17..=32 => inner.cache32.alloc(self.group, self.worker),
                33..=64 => inner.cache64.alloc(self.group, self.worker),
                65..=128 => inner.cache128.alloc(self.group, self.worker),
                129..=256 => inner.cache256.alloc(self.group, self.worker),
                257..=512 => inner.cache512.alloc(self.group, self.worker),
                513..=1024 => inner.cache1024.alloc(self.group, self.worker),
                1025..=HALF_SIZE => inner.cache_half.alloc(self.group, self.worker),
                _ => inner.cache_full.alloc(self.group, self.worker),
            }?
            .cast::<u8>();
            // NOTE(review): the block's bytes are uninitialized here, so this
            // `&[u8]` is formed over uninit memory — presumably only the
            // pointer/length are used downstream; confirm.
            let slice = unsafe { slice::from_raw_parts(buf.as_ptr(), layout.size()) };
            let buf = NonNull::from(slice);
            f(buf)?;
            Ok(buf)
        } else {
            PoolAlloc.alloc_buf(layout, f)
        }
    }
    /// Return a buffer to the cache of the worker that carved its page.
    ///
    /// If the current worker owns the block it is freed onto the local list;
    /// otherwise it is pushed onto the owner's lock-free remote list, so this
    /// side is safe to call from any worker.
    unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: Layout) {
        if layout.size() <= MAX_SIZE {
            let ptr = ptr.cast::<()>();
            // The page trailer records which (group, worker) owns the page.
            let (group, id) = self.group_worker(ptr, layout);
            if let Some(worker) = current_worker() {
                if worker.group_id() == group && worker.worker_id() == id {
                    // Fast path: the calling worker owns the block.
                    return self.free(ptr, layout);
                }
            }
            // Slow path: route to the owner's cache via its atomic remote list.
            Runtime::get(group).task_cache(id).remote_free(ptr, layout);
        } else {
            PoolAlloc.deallocate(ptr, layout);
        }
    }
}

/// Per-worker small-allocation cache: one `BlockCache` per size class
/// (16 B .. 1 KiB powers of two, plus half-page and full-page classes).
///
/// `inner` is only ever touched from the owning worker identified by
/// `group`/`worker`, which is what makes the `UnsafeCell` accesses sound;
/// frees from other workers go through the atomic `remote` lists inside
/// each `BlockCache` instead.
pub(crate) struct MemCache {
    inner: UnsafeCell<Inner>,
    group: u8,  // owning scheduler group
    worker: u16, // owning worker id within the group
}

// SAFETY: the cache is moved to (and then only used by) a single worker;
// cross-worker interaction is limited to the atomic `remote` pointers, so
// sending the cache between threads is sound.
unsafe impl Send for MemCache {}

/// The per-size-class caches. A request is served by the smallest class that
/// fits it (see the dispatch in `alloc_buf` / `MemCache::free`).
#[derive(Default)]
struct Inner {
    cache16: BlockCache<16>,
    cache32: BlockCache<32>,
    cache64: BlockCache<64>,
    cache128: BlockCache<128>,
    cache256: BlockCache<256>,
    cache512: BlockCache<512>,
    cache1024: BlockCache<1024>,
    cache_half: BlockCache<HALF_SIZE>,
    cache_full: BlockCache<MAX_SIZE>,
}

// `MemCache` boxed with `NullAlloc` — presumably the box never frees its
// backing storage itself; confirm against `hipool::NullAlloc`.
pub(crate) type BoxedMemCache = Boxed<'static, MemCache, NullAlloc>;

// Pool used to place new caches (`MemCache::new_in`).
type Pool = &'static MemPool;

impl MemCache {
    /// Empty cache owned by (`group`, `worker`); no pages are allocated yet.
    pub(crate) fn new(group: u8, worker: u16) -> Self {
        Self {
            inner: UnsafeCell::new(Inner::default()),
            group,
            worker,
        }
    }

    /// Allocate a cache inside `pool` and box it there.
    pub(crate) fn new_in(pool: Pool, group: u8, worker: u16) -> Result<BoxedMemCache, Error> {
        Boxed::new_in(pool, Self::new(group, worker)).map(|boxed| boxed.into())
    }

    /// Return every cached block to its page and release fully-free pages
    /// back to the pool, for all size classes.
    ///
    /// Must be called on the owning worker: it takes an exclusive borrow of
    /// `inner` through the `UnsafeCell`.
    pub(crate) fn clean(&self) {
        let inner = unsafe { &mut *self.inner.get() };
        inner.cache16.clean();
        inner.cache32.clean();
        inner.cache64.clean();
        inner.cache128.clean();
        inner.cache256.clean();
        inner.cache512.clean();
        inner.cache1024.clean();
        inner.cache_half.clean();
        inner.cache_full.clean();
    }

    /// Free `ptr` onto the local free list of its size class.
    ///
    /// # Safety
    /// Caller must be the owning worker, and `ptr`/`layout` must describe a
    /// block previously allocated from this cache.
    unsafe fn free(&self, ptr: NonNull<()>, layout: Layout) {
        // SAFETY: owning-worker-only access; no other borrow of `inner` is live.
        let inner = unsafe { &mut *self.inner.get() };
        // Same class dispatch as `alloc_buf` — layout.size() picks the class.
        match layout.size() {
            0..=16 => inner.cache16.free(ptr),
            17..=32 => inner.cache32.free(ptr),
            33..=64 => inner.cache64.free(ptr),
            65..=128 => inner.cache128.free(ptr),
            129..=256 => inner.cache256.free(ptr),
            257..=512 => inner.cache512.free(ptr),
            513..=1024 => inner.cache1024.free(ptr),
            1025..=HALF_SIZE => inner.cache_half.free(ptr),
            HALF_SIZE_1..=MAX_SIZE => inner.cache_full.free(ptr),
            _ => unreachable!("error size"),
        }
    }
    /// Push `ptr` onto the atomic remote list of its size class.
    ///
    /// Safe to call from any worker: only a shared borrow of `inner` is
    /// taken, and `remote_free` uses atomics.
    unsafe fn remote_free(&self, ptr: NonNull<()>, layout: Layout) {
        let inner = unsafe { &*self.inner.get() };
        match layout.size() {
            0..=16 => inner.cache16.remote_free(ptr),
            17..=32 => inner.cache32.remote_free(ptr),
            33..=64 => inner.cache64.remote_free(ptr),
            65..=128 => inner.cache128.remote_free(ptr),
            129..=256 => inner.cache256.remote_free(ptr),
            257..=512 => inner.cache512.remote_free(ptr),
            513..=1024 => inner.cache1024.remote_free(ptr),
            1025..=HALF_SIZE => inner.cache_half.remote_free(ptr),
            HALF_SIZE_1..=MAX_SIZE => inner.cache_full.remote_free(ptr),
            _ => unreachable!("error size"),
        }
    }
    /// Look up which (group, worker) owns the page `ptr` belongs to, by
    /// reading the page trailer of its size class.
    unsafe fn group_worker(&self, ptr: NonNull<()>, layout: Layout) -> (u8, u16) {
        let inner = unsafe { &*self.inner.get() };
        match layout.size() {
            0..=16 => inner.cache16.group_worker(ptr),
            17..=32 => inner.cache32.group_worker(ptr),
            33..=64 => inner.cache64.group_worker(ptr),
            65..=128 => inner.cache128.group_worker(ptr),
            129..=256 => inner.cache256.group_worker(ptr),
            257..=512 => inner.cache512.group_worker(ptr),
            513..=1024 => inner.cache1024.group_worker(ptr),
            1025..=HALF_SIZE => inner.cache_half.group_worker(ptr),
            HALF_SIZE_1..=MAX_SIZE => inner.cache_full.group_worker(ptr),
            _ => unreachable!("error size"),
        }
    }
}

/// Cache for a single size class (`SIZE` bytes per block).
struct BlockCache<const SIZE: usize> {
    // Worker-local singly linked list of free blocks (no synchronization).
    cached: *mut MemBlock,
    // Intrusive list of every page owned by this class. ManuallyDrop —
    // presumably because the nodes live inside the pages themselves, which
    // are released wholesale in `Drop for BlockCache`; confirm against
    // `hicollections::List`'s own Drop.
    pages: ManuallyDrop<List<MemPage>>,
    // Head of a lock-free stack of blocks freed by other workers.
    remote: AtomicPtr<MemBlock>,
}

/// Free-list link overlaid on a block while it is unallocated.
struct MemBlock {
    next: *mut MemBlock,
}

// Counter wide enough for the per-page block count; `block_cnt()` is clamped
// to `off_t::MAX` so it always fits.
#[allow(non_camel_case_types)]
#[cfg(target_pointer_width = "32")]
type off_t = u8;

#[allow(non_camel_case_types)]
#[cfg(target_pointer_width = "64")]
type off_t = u16;

/// Page trailer written at the end of each 4 KiB page, after the blocks
/// (see `BlockCache::new_page` for the layout).
#[repr(C)]
struct MemPage {
    worker: u16, // worker that carved this page
    group: u8,   // group of that worker
    cnt: off_t,  // number of free blocks parked on `cached` (see `clean`)
    node: ListNode, // intrusive link into `BlockCache::pages`
    cached: *mut MemBlock, // blocks parked on this page, reclaimable in bulk
}

/// Build the intrusive page list, anchored on the `node` field of `MemPage`.
fn page_list_new() -> List<MemPage> {
    // SAFETY-adjacent: the closure only computes the address of the embedded
    // node; it never dereferences the page.
    List::new(|page| unsafe { ptr::addr_of!((*page).node) })
}

impl<const SIZE: usize> Default for BlockCache<SIZE> {
    /// Equivalent to [`BlockCache::new`]: an empty cache with no pages.
    fn default() -> Self {
        Self::new()
    }
}

impl<const SIZE: usize> BlockCache<SIZE> {
    /// Empty cache: no free blocks, no pages, empty remote stack.
    fn new() -> Self {
        Self {
            cached: ptr::null_mut(),
            pages: ManuallyDrop::new(page_list_new()),
            remote: AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Pop one block, refilling the local list from (in order): the remote
    /// free stack, blocks parked on existing pages, then a brand-new page.
    fn alloc(&mut self, group: u8, worker: u16) -> Result<NonNull<()>, Error> {
        loop {
            if !self.cached.is_null() {
                // Fast path: pop the head of the local free list.
                let block = unsafe { &*self.cached };
                self.cached = block.next;
                return Ok(NonNull::from(block).cast::<()>());
            }
            // Take the whole remote stack in one swap; Acquire pairs with the
            // Release in `remote_free` so the pushed links are visible.
            let block = self.remote.swap(ptr::null_mut(), Acquire);
            if !block.is_null() {
                self.cached = block;
            } else if !self.reclaim_pages() {
                // Nothing parked on any page either — carve a new page and
                // let the loop retry the fast path.
                self.new_page(group, worker)?;
            }
        }
    }

    /// Push a block onto the local free list (owning worker only).
    ///
    /// # Safety
    /// `ptr` must be a block of this class, previously handed out by `alloc`.
    unsafe fn free(&mut self, ptr: NonNull<()>) {
        let block = ptr.cast::<MemBlock>().as_ptr();
        // Overlay the free-list link on the (now unused) block body.
        unsafe { block.write(MemBlock { next: self.cached }) };
        self.cached = block;
    }

    /// Lock-free (Treiber-stack) push of a block freed by another worker.
    ///
    /// # Safety
    /// `ptr` must be a block of this class; no one else may still use it.
    unsafe fn remote_free(&self, ptr: NonNull<()>) {
        let block = ptr.cast::<MemBlock>().as_mut();
        let mut next = self.remote.load(Relaxed);
        loop {
            block.next = next;
            // `block.next` is the expected old head; Release publishes the
            // link write above to whoever swaps the stack out (see `alloc`).
            match self
                .remote
                .compare_exchange_weak(block.next, block, Release, Relaxed)
            {
                Ok(_) => return,
                Err(old) => next = old,
            }
        }
    }

    /// Drain the local list (and the remote stack) back onto the owning
    /// pages; a page whose every block has come home is released entirely.
    fn clean(&mut self) {
        loop {
            while !self.cached.is_null() {
                let mut block = unsafe { NonNull::new_unchecked(self.cached) };
                self.cached = unsafe { block.as_ref() }.next;
                // Find the page this block came from via its trailer.
                let page = unsafe { self.page(block.cast::<()>()).as_mut() };
                unsafe { self.pages.del(page) };
                if (page.cnt as usize) < (Self::block_cnt() - 1) {
                    // Park the block on its page; pages with parked blocks
                    // live at the head so `reclaim_pages` finds them first.
                    page.cnt += 1;
                    unsafe { block.as_mut() }.next = page.cached;
                    page.cached = block.as_ptr();
                    unsafe { self.pages.add_head(page) };
                } else {
                    // This was the last outstanding block — the whole page is
                    // free, return it to the pool.
                    unsafe { self.drop_page(NonNull::from(page)) };
                }
            }
            // Remote frees may have arrived meanwhile; fold them in and loop.
            let block = self.remote.swap(ptr::null_mut(), Acquire);
            if block.is_null() {
                return;
            }
            self.cached = block;
        }
    }

    /// Read the owning (group, worker) from the block's page trailer.
    unsafe fn group_worker(&self, ptr: NonNull<()>) -> (u8, u16) {
        let page = unsafe { &*self.page(ptr).as_ptr() };
        (page.group, page.worker)
    }
}

impl<const SIZE: usize> BlockCache<SIZE> {
    fn new_page(&mut self, group: u8, worker: u16) -> Result<(), Error> {
        const LAYOUT: Layout = unsafe { Layout::from_size_align_unchecked(PAGE_SIZE, PAGE_SIZE) };
        let cnt = Self::block_cnt();
        unsafe {
            let base = PoolAlloc.allocate(LAYOUT)?.cast::<u8>().as_ptr();
            let mut block = base.cast::<MemBlock>();
            for _ in 0..cnt - 1 {
                let next = block.cast::<u8>().add(block_size(SIZE)).cast::<MemBlock>();
                block.write(MemBlock { next });
                block = next;
            }
            let last = block;
            last.write(MemBlock { next: self.cached });
            self.cached = base.cast::<MemBlock>();
            let page = last.cast::<u8>().add(block_size(SIZE)).cast::<MemPage>();
            page.write(MemPage::new(group, worker));
            self.pages.add_tail(&*page);
        }
        Ok(())
    }

    const fn block_cnt() -> usize {
        let cnt = MAX_SIZE / block_size(SIZE);
        if cnt > off_t::MAX as usize {
            return off_t::MAX as usize;
        }
        cnt
    }

    fn reclaim_pages(&mut self) -> bool {
        if let Some(page) = self.pages.first() {
            if !page.cached.is_null() {
                self.cached = page.cached;
                let page_mut = unsafe { &mut *NonNull::from(page).as_ptr() };
                page_mut.cached = ptr::null_mut();
                page_mut.cnt = 0;
                unsafe { self.pages.del(page) };
                unsafe { self.pages.add_tail(page) };
                return true;
            }
        }
        false
    }

    unsafe fn page(&self, ptr: NonNull<()>) -> NonNull<MemPage> {
        let off = self.offset(ptr);
        let page = ptr
            .cast::<u8>()
            .as_ptr()
            .add(off * block_size(SIZE))
            .cast::<MemPage>();
        NonNull::new_unchecked(page)
    }

    fn offset(&self, ptr: NonNull<()>) -> usize {
        let size = block_size(SIZE);
        let cnt = MAX_SIZE / size;
        let off = ptr.as_ptr() as usize & PAGE_MASK;
        cnt - (off / size)
    }

    unsafe fn drop_page(&self, page: NonNull<MemPage>) {
        const LAYOUT: Layout = unsafe { Layout::from_size_align_unchecked(PAGE_SIZE, PAGE_SIZE) };
        let slice = slice::from_raw_parts(
            page.cast::<u8>().as_ptr().sub(Self::block_cnt() * SIZE),
            PAGE_SIZE,
        );
        PoolAlloc.deallocate(NonNull::from(slice), LAYOUT);
    }
}

impl<const SIZE: usize> Drop for BlockCache<SIZE> {
    /// Release every page of this class back to the pool.
    ///
    /// NOTE(review): `drop_page` frees the memory that also contains the
    /// page's `ListNode` while `pages.iter()` is still walking the list —
    /// this is only sound if the iterator reads the next link before
    /// yielding each node; confirm against `hicollections::List::iter`.
    /// Also assumes all blocks have already been returned (e.g. via
    /// `clean`); any outstanding allocation would dangle after this.
    fn drop(&mut self) {
        for page in self.pages.iter() {
            unsafe { self.drop_page(NonNull::from(page)) };
        }
    }
}

impl MemPage {
    fn new(group: u8, worker: u16) -> MemPage {
        Self {
            cnt: 0,
            group,
            worker,
            node: ListNode::new(),
            cached: ptr::null_mut(),
        }
    }
}

const fn block_size(size: usize) -> usize {
    if size < mem::size_of::<MemBlock>() {
        return mem::size_of::<MemBlock>();
    }
    size
}
